import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import random
import torchvision.datasets as datasets
import logging
import torchvision.transforms as transforms
logger = logging.getLogger(__name__)
from .utils import unpickle, get_unsupervised_transform, other_class, multiclass_noisify


# Per-channel normalization statistics for each digit domain.
# MNIST and USPS images are converted to RGB before the transform runs, so
# their single grayscale statistic is repeated across the three channels.
MNIST_MEAN = (0.13066048,) * 3
MNIST_STD = (0.30810780,) * 3
MNIST_NORMALIZE = transforms.Normalize(MNIST_MEAN, MNIST_STD)

MNIST_M_MEAN = (0.45790559, 0.46209182, 0.40824444)
MNIST_M_STD = (0.25194991, 0.23677807, 0.25870561)
MNIST_M_NORMALIZE = transforms.Normalize(MNIST_M_MEAN, MNIST_M_STD)

USPS_MEAN = (0.24687695,) * 3
USPS_STD = (0.29887581,) * 3
USPS_NORMALIZE = transforms.Normalize(USPS_MEAN, USPS_STD)

SVHN_MEAN = (0.43768210, 0.44376970, 0.47280442)
SVHN_STD = (0.19803012, 0.20101562, 0.19703614)
SVHN_NORMALIZE = transforms.Normalize(SVHN_MEAN, SVHN_STD)


def TRAIN_TRANSFORM(normalize):
    """Build the supervised training pipeline for 32x32 digit images.

    Steps: resize to 32x32, random 32x32 crop with 4 pixels of padding,
    random horizontal flip, conversion to a tensor, then ``normalize``.

    Args:
        normalize: the final (normalization) transform to append.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.
    """
    side = 32
    resize = transforms.Resize([side, side])
    # NOTE(review): horizontal flips mirror digit shapes; confirm this
    # augmentation is intended for digit recognition.
    augment = [transforms.RandomCrop(side, padding=4), transforms.RandomHorizontalFlip()]
    return transforms.Compose([resize, *augment, transforms.ToTensor(), normalize])

def EVAL_TRANSFORM(normalize):
    """Build the deterministic evaluation pipeline for digit images.

    Steps: resize to 32x32, conversion to a tensor, then ``normalize``.

    Args:
        normalize: the final (normalization) transform to append.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.
    """
    pipeline = [transforms.Resize([32, 32]), transforms.ToTensor()]
    pipeline.append(normalize)
    return transforms.Compose(pipeline)


class DIGIT(torch.utils.data.Dataset):
    """Concatenation of the MNIST, MNIST-M, USPS and SVHN digit datasets.

    Global indices run over MNIST first, then MNIST-M, USPS and finally SVHN.
    Each sub-dataset is loaded from the matching sub-directory of
    ``data_path`` (``MNIST``, ``MNIST_M``, ``USPS``, ``SVHN``).

    Args:
        train: load the training split if True, else the test split.
        data_path: root directory containing the four sub-directories.
        unsupervised_transform: forwarded unchanged to every sub-dataset.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        super().__init__()
        self.mnist_dataset = MNIST(train, os.path.join(data_path, 'MNIST'), unsupervised_transform=unsupervised_transform)
        self.mnist_m_dataset = MNIST_M(train, os.path.join(data_path, 'MNIST_M'), unsupervised_transform=unsupervised_transform)
        self.usps_dataset = USPS(train, os.path.join(data_path, 'USPS'), unsupervised_transform=unsupervised_transform)
        self.svhn_dataset = SVHN(train, os.path.join(data_path, 'SVHN'), unsupervised_transform=unsupervised_transform)
        self.len_mnist = len(self.mnist_dataset)
        self.len_mnist_m = len(self.mnist_m_dataset)
        self.len_usps = len(self.usps_dataset)
        self.len_svhn = len(self.svhn_dataset)

    def __getitem__(self, idx):
        """Return the sample at global index ``idx`` from the owning sub-dataset.

        Negative indices count from the end of the concatenation, as for
        Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('index {} out of range for DIGIT of length {}'.format(requested, total))
        # Walk the sub-datasets in concatenation order, shifting idx into each
        # one's local range until it fits.
        parts = (
            (self.mnist_dataset, self.len_mnist),
            (self.mnist_m_dataset, self.len_mnist_m),
            (self.usps_dataset, self.len_usps),
            (self.svhn_dataset, self.len_svhn),
        )
        for dataset, size in parts:
            if idx < size:
                return dataset[idx]
            idx -= size
        # Unreachable after the range check above; kept as a safeguard.
        raise IndexError('index {} out of range for DIGIT of length {}'.format(requested, total))

    def __len__(self):
        """Total number of samples across the four sub-datasets."""
        return self.len_mnist + self.len_mnist_m + self.len_usps + self.len_svhn

class DIGIT_TwoCrops(torch.utils.data.Dataset):
    """Concatenation of the two-view MNIST, MNIST-M, USPS and SVHN datasets.

    Global indices run over MNIST first, then MNIST-M, USPS and finally SVHN.
    Each sub-dataset is loaded from the matching sub-directory of
    ``data_path`` (``MNIST``, ``MNIST_M``, ``USPS``, ``SVHN``).

    Args:
        train: load the training split if True, else the test split.
        data_path: root directory containing the four sub-directories.
        need_transform_: forwarded unchanged to every sub-dataset.
    """

    def __init__(self, train, data_path, need_transform_=False):
        super().__init__()
        self.mnist_dataset = MNIST_TwoCrops(train, os.path.join(data_path, 'MNIST'), need_transform_=need_transform_)
        self.mnist_m_dataset = MNIST_M_TwoCrops(train, os.path.join(data_path, 'MNIST_M'), need_transform_=need_transform_)
        self.usps_dataset = USPS_TwoCrops(train, os.path.join(data_path, 'USPS'), need_transform_=need_transform_)
        self.svhn_dataset = SVHN_TwoCrops(train, os.path.join(data_path, 'SVHN'), need_transform_=need_transform_)
        self.len_mnist = len(self.mnist_dataset)
        self.len_mnist_m = len(self.mnist_m_dataset)
        self.len_usps = len(self.usps_dataset)
        self.len_svhn = len(self.svhn_dataset)

    def __getitem__(self, idx):
        """Return the sample at global index ``idx`` from the owning sub-dataset.

        Negative indices count from the end of the concatenation, as for
        Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('index {} out of range for DIGIT_TwoCrops of length {}'.format(requested, total))
        # Walk the sub-datasets in concatenation order, shifting idx into each
        # one's local range until it fits.
        parts = (
            (self.mnist_dataset, self.len_mnist),
            (self.mnist_m_dataset, self.len_mnist_m),
            (self.usps_dataset, self.len_usps),
            (self.svhn_dataset, self.len_svhn),
        )
        for dataset, size in parts:
            if idx < size:
                return dataset[idx]
            idx -= size
        # Unreachable after the range check above; kept as a safeguard.
        raise IndexError('index {} out of range for DIGIT_TwoCrops of length {}'.format(requested, total))

    def __len__(self):
        """Total number of samples across the four sub-datasets."""
        return self.len_mnist + self.len_mnist_m + self.len_usps + self.len_svhn



class MNIST(datasets.MNIST):
    """torchvision MNIST whose grayscale images are returned as RGB.

    Args:
        train: load the training split if True, else the test split.
        data_path: root directory passed to ``torchvision.datasets.MNIST``.
        unsupervised_transform: use ``get_unsupervised_transform`` instead of
            the supervised train/eval pipeline.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        super().__init__(data_path, train=train)
        if not unsupervised_transform:
            build = TRAIN_TRANSFORM if train else EVAL_TRANSFORM
            self.transform = build(MNIST_NORMALIZE)
        else:
            self.transform = get_unsupervised_transform(normalize=MNIST_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, int label)`` for sample ``idx``."""
        pixels = self.data[idx]
        target = int(self.targets[idx])
        # Expand to 3 channels so MNIST shares one pipeline with the RGB domains.
        rgb = Image.fromarray(pixels.numpy(), mode='L').convert('RGB')
        return self.transform(rgb), target

class MNIST_TwoCrops(datasets.MNIST):
    """torchvision MNIST returning two views of each RGB-converted image.

    ``__getitem__`` yields ``((view1, view2), label)``. ``view2`` always uses
    ``get_unsupervised_transform``. ``view1`` uses ``EVAL_TRANSFORM`` when
    ``need_transform_`` is True and ``get_unsupervised_transform`` otherwise.
    Unlike ``MNIST``, this passes ``download=True`` to torchvision.

    Args:
        train: load the training split if True, else the test split.
        data_path: root directory passed to ``torchvision.datasets.MNIST``.
        need_transform_: use the deterministic eval pipeline for ``view1``.
    """

    def __init__(self, train, data_path, need_transform_=False):
        # Load the requested split, consistent with the other *_TwoCrops datasets.
        super(MNIST_TwoCrops, self).__init__(data_path, train=train, download=True)
        self.transform = get_unsupervised_transform(normalize=MNIST_NORMALIZE)
        self.transform_ = EVAL_TRANSFORM(MNIST_NORMALIZE) if need_transform_ else get_unsupervised_transform(normalize=MNIST_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), int label)`` for sample ``idx``."""
        img, label = self.data[idx], int(self.targets[idx])
        img = Image.fromarray(img.numpy(), mode='L').convert('RGB')
        img1 = self.transform_(img)
        img2 = self.transform(img)
        return (img1, img2), label


class MNIST_M(torch.utils.data.Dataset):
    """MNIST-M digit dataset decoded eagerly from image files on disk.

    ``data_path`` must contain a label file ``mnist_m_<split>_labels.txt`` and
    an image directory ``mnist_m_<split>/``, where ``<split>`` is ``train`` or
    ``test``. Each non-blank label-file line is ``<image file name> <label>``.

    Args:
        train: load the ``train`` split if True, else the ``test`` split.
        data_path: directory holding the label file and the image directory.
        unsupervised_transform: use ``get_unsupervised_transform`` instead of
            the supervised train/eval pipeline.

    Raises:
        ValueError: if a label-file line is not ``<name> <label>``.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        super().__init__()
        self.root = data_path
        split = 'train' if train else 'test'
        img_path = os.path.join(self.root, 'mnist_m_{}'.format(split))
        txt = 'mnist_m_{}_labels.txt'.format(split)

        # The context manager guarantees the label file is closed.
        with open(os.path.join(self.root, txt), 'r') as f:
            file_names, self.label_list = self._parse_label_lines(f)

        self.data = np.array([
            self._load_image(os.path.join(img_path, name)) for name in file_names
        ])

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(normalize=MNIST_M_NORMALIZE)
        else:
            self.transform = TRAIN_TRANSFORM(MNIST_M_NORMALIZE) if train else EVAL_TRANSFORM(MNIST_M_NORMALIZE)

    @staticmethod
    def _parse_label_lines(lines):
        """Split ``<file name> <label>`` lines into parallel name/label lists.

        Splitting on whitespace, rather than slicing fixed character offsets,
        tolerates a missing trailing newline, Windows (CRLF) line endings and
        multi-digit labels. Blank lines are skipped. Labels are returned as
        strings; callers convert them with ``int``.

        Raises:
            ValueError: if a non-blank line does not split into a name and a label.
        """
        names, labels = [], []
        for line_no, line in enumerate(lines, 1):
            fields = line.strip().rsplit(None, 1)
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError('malformed MNIST-M label line {}: {!r}'.format(line_no, line))
            names.append(fields[0])
            labels.append(fields[1])
        return names, labels

    @staticmethod
    def _load_image(path):
        """Decode the image at ``path`` into a numpy array and close the file."""
        with Image.open(path) as im:
            return np.array(im)

    def __getitem__(self, idx):
        """Return ``(transformed image, int label)`` for sample ``idx``."""
        img = self.data[idx]
        label = int(self.label_list[idx])
        img = Image.fromarray(img)
        img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of labelled samples."""
        return len(self.label_list)


class MNIST_M_TwoCrops(MNIST_M):
    """MNIST-M returning two views of each image.

    ``__getitem__`` yields ``((view1, view2), label)``. ``view2`` always uses
    ``get_unsupervised_transform``. ``view1`` uses ``EVAL_TRANSFORM`` when
    ``need_transform_`` is True and ``get_unsupervised_transform`` otherwise.
    """

    def __init__(self, train, data_path, need_transform_=False):
        super().__init__(train, data_path)
        # Replace the single supervised pipeline set up by the parent.
        self.transform = get_unsupervised_transform(normalize=MNIST_M_NORMALIZE)
        if need_transform_:
            self.transform_ = EVAL_TRANSFORM(MNIST_M_NORMALIZE)
        else:
            self.transform_ = get_unsupervised_transform(normalize=MNIST_M_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), int label)`` for sample ``idx``."""
        pixels = self.data[idx]
        target = int(self.label_list[idx])
        picture = Image.fromarray(pixels)
        return (self.transform_(picture), self.transform(picture)), target


class USPS(torch.utils.data.Dataset):
    """USPS 16x16 digits read from a bz2-compressed LIBSVM-style text file.

    Each line holds a label followed by ``index:value`` pixel tokens. Pixel
    values are rescaled with ``(v + 1) / 2 * 255`` and stored as ``uint8``
    images of shape ``(16, 16)``. Labels are shifted down by one (the file's
    labels are presumably 1-based).

    Args:
        train: read ``usps.bz2`` if True, else ``usps.t.bz2``.
        data_path: directory containing the archive.
        unsupervised_transform: use ``get_unsupervised_transform`` instead of
            the supervised train/eval pipeline.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        super().__init__()
        import bz2

        archive = os.path.join(data_path, 'usps.bz2' if train else 'usps.t.bz2')
        with bz2.open(archive) as fp:
            rows = [raw.decode().split() for raw in fp.readlines()]

        # Drop the "index:" prefix of every pixel token and keep the value.
        values = [[token.split(":")[-1] for token in row[1:]] for row in rows]
        pixels = np.asarray(values, dtype=np.float32).reshape((-1, 16, 16))
        self.data = ((pixels + 1) / 2 * 255).astype(dtype=np.uint8)
        self.targets = [int(row[0]) - 1 for row in rows]

        if not unsupervised_transform:
            build = TRAIN_TRANSFORM if train else EVAL_TRANSFORM
            self.transform = build(USPS_NORMALIZE)
        else:
            self.transform = get_unsupervised_transform(normalize=USPS_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, int label)`` for sample ``idx``."""
        pixels = self.data[idx]
        target = int(self.targets[idx])
        rgb = Image.fromarray(pixels, mode='L').convert('RGB')
        return self.transform(rgb), target

    def __len__(self):
        """Number of images."""
        return len(self.data)


class USPS_TwoCrops(USPS):
    """USPS returning two views of each RGB-converted image.

    ``__getitem__`` yields ``((view1, view2), label)``. ``view2`` always uses
    ``get_unsupervised_transform``. ``view1`` uses ``EVAL_TRANSFORM`` when
    ``need_transform_`` is True and ``get_unsupervised_transform`` otherwise.
    """

    def __init__(self, train, data_path, need_transform_=False):
        super().__init__(train, data_path)
        # Replace the single supervised pipeline set up by the parent.
        self.transform = get_unsupervised_transform(normalize=USPS_NORMALIZE)
        if need_transform_:
            self.transform_ = EVAL_TRANSFORM(USPS_NORMALIZE)
        else:
            self.transform_ = get_unsupervised_transform(normalize=USPS_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), int label)`` for sample ``idx``."""
        pixels = self.data[idx]
        target = int(self.targets[idx])
        rgb = Image.fromarray(pixels, mode='L').convert('RGB')
        return (self.transform_(rgb), self.transform(rgb)), target



class SVHN(torch.utils.data.Dataset):
    """SVHN digits loaded from the ``{train,test}_32x32.mat`` files.

    The ``.mat`` file must already exist in ``data_path``; nothing is
    downloaded here (``url``/``file_md5`` are only recorded). Labels are
    mapped with ``% 10`` so that label 10 becomes digit 0.

    Args:
        train: load ``train_32x32.mat`` if True, else ``test_32x32.mat``.
        data_path: directory containing the ``.mat`` file.
        unsupervised_transform: use ``get_unsupervised_transform`` instead of
            the supervised train/eval pipeline.
    """
    url = ""
    filename = ""
    file_md5 = ""

    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
    }

    def __init__(self, train, data_path, unsupervised_transform=False):
        super(SVHN, self).__init__()
        split = 'train' if train else 'test'
        self.url, self.filename, self.file_md5 = self.split_list[split]

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        loaded_mat = sio.loadmat(os.path.join(data_path, self.filename))

        # Reorder axes so each sample is channel-first; __getitem__ transposes
        # back to channel-last for PIL. Presumably 'X' is stored as
        # (H, W, C, N) -- TODO confirm against the .mat files.
        self.data = np.transpose(loaded_mat['X'], (3, 2, 0, 1))
        # Label 10 means digit 0, so the modulo gives zero-based 0..9 labels.
        self.targets = (loaded_mat['y'] % 10).squeeze()

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(normalize=SVHN_NORMALIZE)
        else:
            self.transform = TRAIN_TRANSFORM(SVHN_NORMALIZE) if train else EVAL_TRANSFORM(SVHN_NORMALIZE)

    def __getitem__(self, index):
        """Return ``(transformed image, int label)`` for sample ``index``."""
        # int() turns the numpy scalar into a plain Python int, matching the
        # other digit datasets so batches collate to a consistent dtype.
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Number of images."""
        return len(self.data)



class SVHN_TwoCrops(SVHN):
    """SVHN returning two views of each image.

    ``__getitem__`` yields ``((view1, view2), label)``. ``view2`` always uses
    ``get_unsupervised_transform``. ``view1`` uses ``EVAL_TRANSFORM`` when
    ``need_transform_`` is True and ``get_unsupervised_transform`` otherwise.
    """

    def __init__(self, train, data_path, need_transform_=False):
        super(SVHN_TwoCrops, self).__init__(train, data_path)
        self.transform = get_unsupervised_transform(normalize=SVHN_NORMALIZE)
        self.transform_ = EVAL_TRANSFORM(SVHN_NORMALIZE) if need_transform_ else get_unsupervised_transform(normalize=SVHN_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), int label)`` for sample ``idx``."""
        data = self.data[idx]
        # Plain int label (the stored targets are numpy values), matching the
        # other digit datasets so batches collate to a consistent dtype.
        label = int(self.targets[idx])
        image = Image.fromarray(np.transpose(data, (1, 2, 0)))
        image1 = self.transform_(image)
        image2 = self.transform(image)
        return (image1, image2), label



class NoisySVHN(SVHN):
    """SVHN whose labels are synthetically corrupted via a transition matrix.

    When ``noise_rate > 0``, labels are resampled by ``multiclass_noisify``
    with a 10x10 matrix ``P`` where ``P[i, j]`` is the weight for turning
    class ``i`` into class ``j``:

    * ``is_asym=True``: flips 7->1, 2->7, 5<->6 and 3->8, each with
      probability ``noise_rate``; other classes keep their label.
    * ``is_asym=False``: symmetric noise; each class keeps its label with
      probability ``1 - noise_rate`` and spreads the rest evenly over the
      other 9 classes.

    With ``noise_rate <= 0`` the labels are left unchanged in both modes.
    Per-class label counts are logged in all cases.

    Args:
        train, data_path, unsupervised_transform: forwarded to ``SVHN``.
        noise_rate: corruption probability described above.
        is_asym: choose the asymmetric (True) or symmetric (False) scheme.
        seed: seeds numpy's global RNG and is passed as ``random_state``.

    Raises:
        AssertionError: if noise was requested but no label actually changed.
    """

    def __init__(self, train, data_path, unsupervised_transform=False, noise_rate=0.0, is_asym=False, seed=0):
        super(NoisySVHN, self).__init__(train, data_path, unsupervised_transform=unsupervised_transform)
        np.random.seed(seed)
        # Gate both schemes on noise_rate > 0: an asymmetric P built with
        # n == 0 is the identity, which would trip the sanity assert below.
        if noise_rate > 0:
            if is_asym:
                P = self._asymmetric_transition(noise_rate)
            else:
                P = self._symmetric_transition(noise_rate)
            y_train_noisy = multiclass_noisify(self.targets, P=P, random_state=seed)
            actual_noise = (y_train_noisy != self.targets).mean()
            assert actual_noise > 0.0
            logger.info('Actual noise %.2f', actual_noise)
            self.targets = y_train_noisy
        logger.info('Print noisy label generation statistics:')
        targets = np.asarray(self.targets)
        for cls in range(10):
            logger.info("Noisy class %s, has %s samples.", cls, np.sum(targets == cls))

    @staticmethod
    def _asymmetric_transition(n):
        """Return the 10x10 matrix flipping 7->1, 2->7, 5<->6, 3->8 with probability ``n``."""
        P = np.eye(10)
        for src, dst in ((7, 1), (2, 7), (5, 6), (6, 5), (3, 8)):
            P[src, src], P[src, dst] = 1. - n, n
        return P

    @staticmethod
    def _symmetric_transition(n):
        """Return the 10x10 matrix keeping a label w.p. ``1 - n``, else uniform over the other 9."""
        P = np.full((10, 10), n / (10 - 1))
        np.fill_diagonal(P, 1. - n)
        return P

class NoisySVHN_TwoCrops(NoisySVHN):
    """Noisy-label SVHN returning two views of each image.

    ``__getitem__`` yields ``((view1, view2), label)``. ``view2`` always uses
    ``get_unsupervised_transform``. ``view1`` uses ``EVAL_TRANSFORM`` when
    ``need_transform_`` is True and ``get_unsupervised_transform`` otherwise.

    Args:
        train, data_path, unsupervised_transform, noise_rate, is_asym:
            forwarded to ``NoisySVHN``.
        need_transform_: use the deterministic eval pipeline for ``view1``.
        seed: forwarded to ``NoisySVHN`` so the label corruption can use a
            seed other than the default 0.
    """

    def __init__(self, train, data_path, unsupervised_transform=False, need_transform_=False, noise_rate=0.0, is_asym=False, seed=0):
        super(NoisySVHN_TwoCrops, self).__init__(
            train, data_path, unsupervised_transform=unsupervised_transform,
            noise_rate=noise_rate, is_asym=is_asym, seed=seed)
        self.transform = get_unsupervised_transform(normalize=SVHN_NORMALIZE)
        self.transform_ = EVAL_TRANSFORM(SVHN_NORMALIZE) if need_transform_ else get_unsupervised_transform(normalize=SVHN_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), int label)`` for sample ``idx``."""
        data = self.data[idx]
        # Plain int label, matching the other digit datasets.
        label = int(self.targets[idx])
        image = Image.fromarray(np.transpose(data, (1, 2, 0)))
        image1 = self.transform_(image)
        image2 = self.transform(image)
        return (image1, image2), label